--- external/gsn/models/generator.py	2023-03-28 16:58:33.300084813 +0000
+++ external_reference/gsn/models/generator.py	2023-03-28 16:00:32.533475741 +0000
@@ -5,10 +5,11 @@
 import torch
 from torch import nn
 
-from .layers import *
-from utils.utils import instantiate_from_config
-from .nerf_utils import get_sample_points, volume_render_radiance_field, sample_pdf_2
-
+# use absolute imports -- relative imports cause a persistence error
+from external.gsn.models.layers import *
+from external.gsn.models.nerf_utils import get_sample_points, volume_render_radiance_field, sample_pdf_2
+from torch_utils import persistence
+import dnnlib
 
 class StyleGenerator2D(nn.Module):
     def __init__(self, out_res, out_ch, z_dim, ch_mul=1, ch_max=512, skip_conn=True):
@@ -140,7 +141,7 @@
             skip = self.out_rgb(out, z[i])
         return skip
 
-
+@persistence.persistent_class
 class NerfStyleGenerator(nn.Module):
     """NeRF MLP with style modulation.
 
@@ -175,17 +176,19 @@
 
         self.skips = skips
 
-        self.from_coords = PositionalEncoding(in_dim=3, frequency_bands=omega_coord)
-        self.from_dirs = PositionalEncoding(in_dim=3, frequency_bands=omega_dir)
+        self.shift_y = 1 # shift_y=1 prevents discontinuity on ground plane
+        viewdir_dim = 0 # ignore view directions
+        coord_dim = 3 # xyz
+
         self.n_layers = n_layers
 
         self.layers = nn.ModuleList(
-            [ModulationLinear(in_channel=self.from_coords.out_dim, out_channel=channels, z_dim=z_dim)]
+            [ModulationLinear(in_channel=coord_dim, out_channel=channels, z_dim=z_dim)]
         )
 
         for i in range(1, n_layers):
             if i in skips:
-                in_channels = channels + self.from_coords.out_dim
+                in_channels = channels + coord_dim
             else:
                 in_channels = channels
             self.layers.append(ModulationLinear(in_channel=in_channels, out_channel=channels, z_dim=z_dim))
@@ -195,7 +198,7 @@
         )
         self.fc_feat = ModulationLinear(in_channel=channels, out_channel=channels, z_dim=z_dim)
         self.fc_viewdir = ModulationLinear(
-            in_channel=channels + self.from_dirs.out_dim, out_channel=channels, z_dim=z_dim
+            in_channel=channels + viewdir_dim, out_channel=channels, z_dim=z_dim
         )
         self.fc_out = ModulationLinear(
             in_channel=channels, out_channel=out_channel, z_dim=z_dim, demodulate=False, activate=False
@@ -214,6 +217,14 @@
             z = [z[:, i] for i in range(n_latents)]
         return z
 
+    def extract_height(self, coords):
+        # CHANGED: positional encoding removed since it causes repeating
+        # patterns; return only y, the height above the ground plane (x, z zeroed)
+        encoding = coords.clone()
+        encoding[:, 0::3] = 0
+        encoding[:, 2::3] = 0
+        return encoding
+
     def forward(self, z, coords, viewdirs=None):
         """Forward pass.
 
@@ -234,7 +245,8 @@
             Occupancy values of shape [B, 1].
 
         """
-        coords = self.from_coords(coords)
+        coords[:, 1] += self.shift_y
+        coords = self.extract_height(coords)
         z = self.process_latents(z)
 
         h = coords
@@ -251,15 +263,14 @@
 
         h = self.fc_feat(h, z[i + 2])
 
-        viewdirs = self.from_dirs(viewdirs)
-        h = torch.cat([h, viewdirs], dim=-1)
+        # viewdirs = self.from_dirs(viewdirs)
+        # h = torch.cat([h, viewdirs], dim=-1)
 
         h = self.fc_viewdir(h, z[i + 3])
         out = self.fc_out(h, z[i + 4])
 
         return out, alpha
 
-
 class NerfSimpleGenerator(nn.Module):
     """NeRF MLP with with standard latent concatenation.
 
@@ -420,7 +431,8 @@
         rgb = torch.sigmoid(rgb)
         return rgb
 
-
+# CHANGED: keep track of inference feature resolution and adjust feature sampling
+@persistence.persistent_class
 class SceneGenerator(nn.Module):
     """NeRF scene generator.
 
@@ -466,6 +478,7 @@
         local_coordinates=True,
         hierarchical_sampling=False,
         density_bias=0,
+        use_disp=True,
         **kwargs
     ):
         super().__init__()
@@ -478,9 +491,12 @@
         self.local_coordinates = local_coordinates
         self.hierarchical_sampling = hierarchical_sampling
         self.density_bias = density_bias
-        self.out_dim = nerf_mlp_config.params.out_channel
+        self.out_dim = nerf_mlp_config.out_channel
+        self.use_disp = use_disp
+
+        self.local_generator = dnnlib.util.construct_class_by_name(**nerf_mlp_config)
+        self.inference_feat_res = None
 
-        self.local_generator = instantiate_from_config(nerf_mlp_config)
 
     def get_local_coordinates(self, global_coords, local_grid_length, preserve_y=True):
         local_coords = global_coords.clone()
@@ -490,7 +506,16 @@
         # scale so that each grid cell in the local_latent grid is 1x1 in size
         local_coords = local_coords * local_grid_length
         # subtract integer from each coordinate so that they are all in range [0, 1]
-        local_coords = local_coords - (local_coords - 0.5).round()
+        if self.inference_feat_res is not None:
+            offset = (self.inference_feat_res - self.global_feat_res) // 2
+            local_coords = local_coords + offset
+            # now local_coords should be in range [0, inference_feat_res]
+            local_coords = local_coords - local_coords.clip(0, self.inference_feat_res).floor()
+        else:
+            # clip so that the input differs when local coords go off the grid
+            local_coords = local_coords - local_coords.clip(0, local_grid_length).floor()
+            # local_coords = local_coords - (local_coords - 0.5).round()
+
         # return to [-1, 1] scale
         local_coords = (local_coords * 2) - 1
 
@@ -516,6 +541,11 @@
 
         samples_per_ray = xyz.shape[2]
 
+        if self.inference_feat_res is not None:
+            # adjust the coordinates for grid sampling the local latents
+            assert(self.inference_feat_res == local_latents.shape[-1])
+            xyz = xyz * self.global_feat_res / self.inference_feat_res
+
         # all samples get the most detailed latent codes
         sampled_local_latents = nn.functional.grid_sample(
             input=local_latents,
@@ -542,6 +572,12 @@
             # this tries to get all input coordinates to lie within [-1, 1]
             xyz = xyz / (self.coordinate_scale / 2)
 
+        if local_latents.shape[-1] != self.global_feat_res:
+            # coordinate adjustment for inference time sampling
+            self.inference_feat_res = local_latents.shape[-1]
+        else:
+            self.inference_feat_res = None
+
         B, n_samples, samples_per_ray, _ = xyz.shape  # n_samples = H * W
         sampled_local_latents, local_latents = self.sample_local_latents(local_latents, xyz=xyz)
 
@@ -591,10 +627,8 @@
             H, W = render_params.nerf_out_res, render_params.nerf_out_res
 
             # if using feature-NeRF, need to adjust camera intrinsics to account for lower sampling resolution
-            if self.img_res is not None:
-                downsampling_ratio = render_params.nerf_out_res / self.img_res
-            else:
-                downsampling_ratio = 1
+            # note: intrinsics K are expected to already match the nerf output resolution
+            downsampling_ratio = 1
             fx, fy = render_params.K[0, 0, 0] * downsampling_ratio, render_params.K[0, 1, 1] * downsampling_ratio
             xyz, viewdirs, z_vals, rd, ro = get_sample_points(
                 tform_cam2world=render_params.Rt.inverse(),
@@ -618,7 +652,7 @@
             return alpha_coarse
 
         if self.hierarchical_sampling:
-            _, _, _, weights, _, occupancy_prior = volume_render_radiance_field(
+            _, _, _, weights, _, occupancy_prior, _ = volume_render_radiance_field(
                 rgb=rgb_coarse,
                 occupancy=alpha_coarse,
                 depth_values=z_vals,
@@ -627,6 +661,7 @@
                 alpha_activation=self.alpha_activation,
                 activate_rgb=not self.feature_nerf,
                 density_bias=self.density_bias,
+                use_disp=self.use_disp
             )
 
             z_vals_fine = self.importance_sampling(z_vals, weights, render_params.samples_per_ray)
@@ -648,7 +683,8 @@
             rgb, alpha = rgb_coarse, alpha_coarse
             z_vals = z_vals
 
-        rgb, _, _, _, depth, occupancy_prior = volume_render_radiance_field(
+        # note: if use_disp=True, then the depth output is inverse depth
+        rgb, disp, acc, weights, depth, occupancy_prior, extras = volume_render_radiance_field(
             rgb=rgb,
             occupancy=alpha,
             depth_values=z_vals,
@@ -657,14 +693,20 @@
             alpha_activation=self.alpha_activation,
             activate_rgb=not self.feature_nerf,
             density_bias=self.density_bias,
+            use_disp=self.use_disp
         )
 
         out = {
             'rgb': rgb,
             'depth': depth,
+            'acc': acc,
             'Rt': render_params.Rt,
             'K': render_params.K,
             'local_latents': local_latents,
             'occupancy_prior': occupancy_prior,
+            'xyz': xyz,  # also return the sampled points
+            'weights': weights, # and weights at each point
+            'alpha': extras['alpha'], # for opacity regularization
+            'dists': extras['dists'], # for opacity regularization
         }
         return out
